{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Copyright 2025 Bytedance Ltd. and/or its affiliates.\n",
    "# SPDX-License-Identifier: Apache-2.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from copy import deepcopy\n",
    "from typing import (\n",
    "    Any,\n",
    "    AsyncIterable,\n",
    "    Callable,\n",
    "    Dict,\n",
    "    Generator,\n",
    "    List,\n",
    "    NamedTuple,\n",
    "    Optional,\n",
    "    Tuple,\n",
    "    Union,\n",
    ")\n",
    "import requests\n",
    "from io import BytesIO\n",
    "\n",
    "from PIL import Image\n",
    "import torch\n",
    "from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights\n",
    "\n",
    "from data.transforms import ImageTransform\n",
    "from data.data_utils import pil_img2rgb, add_special_tokens\n",
    "from modeling.bagel import (\n",
    "    BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel\n",
    ")\n",
    "from modeling.qwen2 import Qwen2Tokenizer\n",
    "from modeling.bagel.qwen2_navit import NaiveCache\n",
    "from modeling.autoencoder import load_ae\n",
    "from safetensors.torch import load_file"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model_path = \"/path/to/BAGEL-7B-MoT/weights\"  # Download from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT\n",
    "\n",
    "# LLM config preparing\n",
    "llm_config = Qwen2Config.from_json_file(os.path.join(model_path, \"llm_config.json\"))\n",
    "llm_config.qk_norm = True\n",
    "llm_config.tie_word_embeddings = False\n",
    "llm_config.layer_module = \"Qwen2MoTDecoderLayer\"\n",
    "\n",
    "# ViT config preparing\n",
    "vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, \"vit_config.json\"))\n",
    "vit_config.rope = False\n",
    "vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1\n",
    "\n",
    "# VAE loading\n",
    "vae_model, vae_config = load_ae(local_path=os.path.join(model_path, \"ae.safetensors\"))\n",
    "\n",
    "# Bagel config preparing\n",
    "config = BagelConfig(\n",
    "    visual_gen=True,\n",
    "    visual_und=True,\n",
    "    llm_config=llm_config, \n",
    "    vit_config=vit_config,\n",
    "    vae_config=vae_config,\n",
    "    vit_max_num_patch_per_side=70,\n",
    "    connector_act='gelu_pytorch_tanh',\n",
    "    latent_patch_size=2,\n",
    "    max_latent_size=64,\n",
    ")\n",
    "\n",
    "with init_empty_weights():\n",
    "    language_model = Qwen2ForCausalLM(llm_config)\n",
    "    vit_model      = SiglipVisionModel(vit_config)\n",
    "    model          = Bagel(language_model, vit_model, config)\n",
    "    model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)\n",
    "\n",
    "# Tokenizer Preparing\n",
    "tokenizer = Qwen2Tokenizer.from_pretrained(model_path)\n",
    "tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)\n",
    "\n",
    "# Image Transform Preparing\n",
    "vae_transform = ImageTransform(1024, 512, 16)\n",
    "vit_transform = ImageTransform(980, 224, 14)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Loading and Multi GPU Infernece Preparing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "max_mem_per_gpu = \"40GiB\"  # Modify it according to your GPU setting. On an A100, 80 GiB is sufficient to load on a single GPU.\n",
    "\n",
    "device_map = infer_auto_device_map(\n",
    "    model,\n",
    "    max_memory={i: max_mem_per_gpu for i in range(torch.cuda.device_count())},\n",
    "    no_split_module_classes=[\"Bagel\", \"Qwen2MoTDecoderLayer\"],\n",
    ")\n",
    "print(device_map)\n",
    "\n",
    "same_device_modules = [\n",
    "    'language_model.model.embed_tokens',\n",
    "    'time_embedder',\n",
    "    'latent_pos_embed',\n",
    "    'vae2llm',\n",
    "    'llm2vae',\n",
    "    'connector',\n",
    "    'vit_pos_embed'\n",
    "]\n",
    "\n",
    "if torch.cuda.device_count() == 1:\n",
    "    first_device = device_map.get(same_device_modules[0], \"cuda:0\")\n",
    "    for k in same_device_modules:\n",
    "        if k in device_map:\n",
    "            device_map[k] = first_device\n",
    "        else:\n",
    "            device_map[k] = \"cuda:0\"\n",
    "else:\n",
    "    first_device = device_map.get(same_device_modules[0])\n",
    "    for k in same_device_modules:\n",
    "        if k in device_map:\n",
    "            device_map[k] = first_device\n",
    "\n",
    "# Thanks @onion-liu: https://github.com/ByteDance-Seed/Bagel/pull/8\n",
    "model = load_checkpoint_and_dispatch(\n",
    "    model,\n",
    "    checkpoint=os.path.join(model_path, \"ema.safetensors\"),\n",
    "    device_map=device_map,\n",
    "    offload_buffers=True,\n",
    "    dtype=torch.bfloat16,\n",
    "    force_hooks=True,\n",
    "    offload_folder=\"/tmp/offload\"\n",
    ")\n",
    "\n",
    "model = model.eval()\n",
    "print('Model loaded')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inferencer Preparing "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from inferencer import InterleaveInferencer\n",
    "\n",
    "inferencer = InterleaveInferencer(\n",
    "    model=model, \n",
    "    vae_model=vae_model, \n",
    "    tokenizer=tokenizer, \n",
    "    vae_transform=vae_transform, \n",
    "    vit_transform=vit_transform, \n",
    "    new_token_ids=new_token_ids\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "\n",
    "seed = 42\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**About Inference Hyperparameters:**\n",
    "- **`cfg_text_scale`:** Controls how strongly the model follows the text prompt. `1.0` disables text guidance. Typical range: `4.0–8.0`.\n",
    "- **`cfg_image_scale`:** Controls how much the model preserves input image details. `1.0` disables image guidance. Typical range: `1.0–2.0`.\n",
    "- **`cfg_interval`:** Fraction of denoising steps where CFG is applied. Later steps can skip CFG to reduce computation. Typical: `[0.4, 1.0]`.\n",
    "- **`timestep_shift`:** Shifts the distribution of denoising steps. Higher values allocate more steps at the start (affects layout); lower values allocate more at the end (improves details).\n",
    "- **`num_timesteps`:** Total denoising steps. Typical: `50`.\n",
    "- **`cfg_renorm_min`:** Minimum value for CFG-Renorm. `1.0` disables renorm. Typical: `0`.\n",
    "- **`cfg_renorm_type`:** CFG-Renorm method:  \n",
    "  - `global`: Normalize over all tokens and channels (default for T2I).\n",
    "  - `channel`: Normalize across channels for each token.\n",
    "  - `text_channel`: Like `channel`, but only applies to text condition (good for editing, may cause blur).\n",
    "- **If edited images appear blurry, try `global` CFG-Renorm, decrease `cfg_renorm_min` or decrease `cfg_scale`.**\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Image Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "inference_hyper=dict(\n",
    "    cfg_text_scale=4.0,\n",
    "    cfg_img_scale=1.0,\n",
    "    cfg_interval=[0.4, 1.0],\n",
    "    timestep_shift=3.0,\n",
    "    num_timesteps=50,\n",
    "    cfg_renorm_min=0.0,\n",
    "    cfg_renorm_type=\"global\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "prompt = \"A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere.\"\n",
    "\n",
    "print(prompt)\n",
    "print('-' * 10)\n",
    "output_dict = inferencer(text=prompt, **inference_hyper)\n",
    "display(output_dict['image'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Image Generation with Think"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "inference_hyper=dict(\n",
    "    max_think_token_n=1000,\n",
    "    do_sample=False,\n",
    "    # text_temperature=0.3,\n",
    "    cfg_text_scale=4.0,\n",
    "    cfg_img_scale=1.0,\n",
    "    cfg_interval=[0.4, 1.0],\n",
    "    timestep_shift=3.0,\n",
    "    num_timesteps=50,\n",
    "    cfg_renorm_min=0.0,\n",
    "    cfg_renorm_type=\"global\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "prompt = 'a car made of small cars'\n",
    "\n",
    "print(prompt)\n",
    "print('-' * 10)\n",
    "output_dict = inferencer(text=prompt, think=True, **inference_hyper)\n",
    "print(output_dict['text'])\n",
    "display(output_dict['image'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Editing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "inference_hyper=dict(\n",
    "    cfg_text_scale=4.0,\n",
    "    cfg_img_scale=2.0,\n",
    "    cfg_interval=[0.0, 1.0],\n",
    "    timestep_shift=3.0,\n",
    "    num_timesteps=50,\n",
    "    cfg_renorm_min=0.0,\n",
    "    cfg_renorm_type=\"text_channel\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "image = Image.open('test_images/women.jpg')\n",
    "prompt = 'She boards a modern subway, quietly reading a folded newspaper, wearing the same clothes.'\n",
    "\n",
    "display(image)\n",
    "print(prompt)\n",
    "print('-'*10)\n",
    "output_dict = inferencer(image=image, text=prompt, **inference_hyper)\n",
    "display(output_dict['image'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Edit with Think"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "inference_hyper=dict(\n",
    "    max_think_token_n=1000,\n",
    "    do_sample=False,\n",
    "    # text_temperature=0.3,\n",
    "    cfg_text_scale=4.0,\n",
    "    cfg_img_scale=2.0,\n",
    "    cfg_interval=[0.0, 1.0],\n",
    "    timestep_shift=3.0,\n",
    "    num_timesteps=50,\n",
    "    cfg_renorm_min=0.0,\n",
    "    cfg_renorm_type=\"text_channel\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "image = Image.open('test_images/octupusy.jpg')\n",
    "prompt = 'Could you display the sculpture that takes after this design?'\n",
    "\n",
    "display(image)\n",
    "print('-'*10)\n",
    "output_dict = inferencer(image=image, text=prompt, think=True, **inference_hyper)\n",
    "print(output_dict['text'])\n",
    "display(output_dict['image'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Understanding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "inference_hyper=dict(\n",
    "    max_think_token_n=1000,\n",
    "    do_sample=False,\n",
    "    # text_temperature=0.3,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "image = Image.open('test_images/meme.jpg')\n",
    "prompt = \"Can someone explain what’s funny about this meme??\"\n",
    "\n",
    "display(image)\n",
    "print(prompt)\n",
    "print('-'*10)\n",
    "output_dict = inferencer(image=image, text=prompt, understanding_output=True, **inference_hyper)\n",
    "print(output_dict['text'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TaylorSeer Support for Inference Acceleration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable TaylorSeer (https://arxiv.org/abs/2503.06923) by setting enable_taylorseer=True in inference_hyper\n",
    "# This feature is supported for tasks that involve image output, including image generation, image generation with think, editing, and edit with think.\n",
    "\n",
    "# Below is an example of image generation with TaylorSeer enabled.\n",
    "inference_hyper=dict(\n",
    "    cfg_text_scale=4.0,\n",
    "    cfg_img_scale=1.0,\n",
    "    cfg_interval=[0.4, 1.0],\n",
    "    timestep_shift=3.0,\n",
    "    num_timesteps=50,\n",
    "    cfg_renorm_min=0.0,\n",
    "    cfg_renorm_type=\"global\",\n",
    "    enable_taylorseer=True,  # Enable TaylorSeer\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere.\"\n",
    "\n",
    "print(prompt)\n",
    "print('-' * 10)\n",
    "output_dict = inferencer(text=prompt, **inference_hyper)\n",
    "display(output_dict['image'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "fileId": "1bfaa82d-51b0-4c13-9e4c-295ba28bcd8a",
  "filePath": "/mnt/bn/seed-aws-va/chaorui/code/cdt-hf/notebooks/chat.ipynb",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
